package com.zusmart.base.network.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

import com.zusmart.base.buffer.Buffer;
import com.zusmart.base.buffer.BufferAllocator;
import com.zusmart.base.future.FutureListener;
import com.zusmart.base.network.ChannelContextFuture;
import com.zusmart.base.network.ChannelContextManager;
import com.zusmart.base.network.ChannelOption;
import com.zusmart.base.network.ChannelUtils;
import com.zusmart.base.network.message.Message;
import com.zusmart.base.network.message.MessageFuture;
import com.zusmart.base.network.message.MessageProtocol;
import com.zusmart.base.network.support.AbstractChannelContext;
import com.zusmart.base.toolkit.NetAddress;
import com.zusmart.base.util.Assert;

/**
 * Channel context backed by a NIO {@link SocketChannel} that is registered with a
 * {@link NioChannelEventLoop}.
 *
 * <p>Outgoing messages are queued by {@link #write(Message)} in a concurrent queue and are only
 * encoded and written to the socket when the selector reports the channel writable
 * ({@link #doWritable()}), which {@link #flush()} requests. Incoming bytes are accumulated in a
 * reader buffer and handed to the handler chain by {@link #doReadable()}.
 */
public class NioChannelContext extends AbstractChannelContext {

	private static final int EVENT_READABLE = SelectionKey.OP_READ;
	private static final int EVENT_WRITABLE = SelectionKey.OP_WRITE;

	private SelectionKey selectionKey;
	private final SocketChannel socketChannel;
	private final ChannelOption channelOption;
	private final BufferAllocator bufferAllocator;
	private final MessageProtocol messageProtocol;
	private final NioChannelEventLoop channelEventLoop;

	private NetAddress serverAddress;
	private NetAddress clientAddress;

	// Bytes received but not yet consumed by the handler chain; null when nothing is pending.
	private Buffer readerBuffer;
	// Concurrent so that write() may be called from threads other than the event loop.
	private final Queue<MessageFuture> writerMessage = new ConcurrentLinkedQueue<MessageFuture>();

	// Set once by close(); volatile because it is also read on the event-loop thread (doWritable).
	private volatile ChannelContextFuture closeFuture;

	/**
	 * Creates a context for an already-open socket channel.
	 *
	 * @param serverSide       whether this context lives on the accepting (server) side
	 * @param contextCode      identifier passed to the base context
	 * @param socketChannel    the underlying socket channel
	 * @param contextManager   manager passed to the base context
	 * @param messageProtocol  protocol used to encode outgoing messages
	 * @param channelEventLoop loop that owns this channel; also supplies buffer allocator and options
	 */
	public NioChannelContext(boolean serverSide, String contextCode, SocketChannel socketChannel, ChannelContextManager contextManager, MessageProtocol messageProtocol, NioChannelEventLoop channelEventLoop) {
		super(serverSide, contextCode, contextManager);
		// NOTE(review): judging by the messages, Assert.isNull is used here as a not-null precondition - confirm.
		Assert.isNull(socketChannel, "socket channel must not be null");
		Assert.isNull(channelEventLoop, "channel event loop must not be null");
		Assert.isNull(messageProtocol, "message protocol must not be null");
		this.socketChannel = socketChannel;
		this.messageProtocol = messageProtocol;
		this.bufferAllocator = channelEventLoop.getBufferAllocator();
		this.channelOption = channelEventLoop.getChannelOption();
		this.channelEventLoop = channelEventLoop;
	}

	@Override
	public boolean isOpen() {
		return null != this.socketChannel && this.socketChannel.isOpen();
	}

	@Override
	public NetAddress getServerAddress() {
		return this.serverAddress;
	}

	@Override
	public NetAddress getClientAddress() {
		return this.clientAddress;
	}

	/**
	 * Requests that queued messages be written by switching the selection key's interest set to
	 * OP_WRITE (this replaces OP_READ until {@link #doWritable()} restores it).
	 *
	 * <p>On the event-loop thread the interest set is changed directly; from any other thread the
	 * change is delegated to the event loop and the returned future mirrors that result.
	 *
	 * @return a future completed once the interest change has been applied (or failed)
	 */
	@Override
	public ChannelContextFuture flush() {
		final ChannelContextFuture future = new ChannelContextFuture(this);
		if (this.channelEventLoop.inEventLoop()) {
			try {
				this.doListener(EVENT_WRITABLE);
				future.setSuccess();
			} catch (IOException e) {
				this.doException(e);
				future.setFailure(e);
			}
		} else {
			this.channelEventLoop.doListener(this, EVENT_WRITABLE).attachListener(new FutureListener<ChannelContextFuture>() {
				@Override
				public void execute(ChannelContextFuture f) throws Exception {
					if (f.isSuccessed()) {
						future.setSuccess();
					} else {
						future.setFailure(f.getFailureCause());
					}
				}
			});
		}
		return future;
	}

	/**
	 * Queues a message for sending; nothing is written until {@link #flush()} triggers
	 * {@link #doWritable()}.
	 *
	 * @param message the message to send
	 * @return a future completed when the message is written, or failed immediately when the
	 *         message is null or the channel is already closed
	 */
	@Override
	public MessageFuture write(Message message) {
		MessageFuture future = new MessageFuture(this, message);
		if (null == message) {
			future.setFailure(new RuntimeException("message is null"));
		} else if (!this.isOpen()) {
			// A closed socket can never be written to, so queuing would leave this future pending forever.
			future.setFailure(new IOException("channel is closed"));
		} else {
			this.writerMessage.add(future);
		}
		return future;
	}

	public NioChannelEventLoop getNioChannelEventLoop() {
		return this.channelEventLoop;
	}

	/////////////////////////////////////////////////////////////////////////////

	/**
	 * Switches the channel to non-blocking mode, registers it for OP_READ with this context as the
	 * key attachment, captures local/remote addresses and notifies the handler chain.
	 */
	@Override
	public void doRegister(Selector selector) throws IOException {
		this.socketChannel.configureBlocking(false);
		this.selectionKey = this.socketChannel.register(selector, EVENT_READABLE, this);
		this.updateAddress();
		this.updateOnRegist();
		this.getChannelContextHandlerChain().fireOnRegister();
	}

	/** Replaces the selection key's interest set with {@code events}. */
	@Override
	public void doListener(int events) throws IOException {
		this.selectionKey.interestOps(events);
	}

	@Override
	public void deRegister() {
		this.shutdown();
	}

	@Override
	public void doException(Throwable cause) {
		this.getChannelContextHandlerChain().fireOnException(cause);
	}

	@Override
	public void doTimeout() {
		this.getChannelContextHandlerChain().fireOnTimeout();
	}

	/**
	 * Reads everything currently available from the socket into a fresh buffer, appends it to any
	 * bytes left over from the previous read and hands the result to the handler chain.
	 *
	 * <p>An I/O error while reading, or end-of-stream, shuts the channel down; other exceptions
	 * are reported through {@link #doException(Throwable)}.
	 */
	@Override
	public void doReadable() {
		Buffer buffer = this.bufferAllocator.allocate(this.channelOption.getBufferSize(), this.channelOption.isDirectBuffer());
		int total = 0;
		int count = -1;
		try {
			// Stop on 0 (nothing more available right now / buffer full) or -1 (end of stream).
			while ((count = buffer.read(this.socketChannel)) > 0) {
				total += count;
			}
			this.updateOnReader(total);
		} catch (IOException e) {
			buffer.release();
			this.deRegister();
			return;
		} catch (Exception e) {
			buffer.release();
			this.doException(e);
			return;
		}
		if (total > 0) {
			buffer.flip();
			if (null == this.readerBuffer) {
				this.readerBuffer = buffer;
			} else {
				// readerBuffer was compacted after the last dispatch, so (assuming ByteBuffer-like
				// semantics) position() is the carried-over byte count and remaining() the free space.
				if (this.readerBuffer.remaining() < buffer.remaining()) {
					Buffer newBuffer = this.bufferAllocator.allocate(this.readerBuffer.position() + buffer.remaining(), buffer.isDirect());
					newBuffer.put(this.readerBuffer.flip());
					this.readerBuffer.release();
					this.readerBuffer = newBuffer;
				}
				this.readerBuffer.put(buffer);
				this.readerBuffer.flip();
				buffer.release();
			}
			try {
				this.getChannelContextHandlerChain().fireOnReader(this.readerBuffer);
			} catch (Exception e) {
				this.doException(e);
			} finally {
				if (this.readerBuffer.hasRemaining()) {
					// Keep unconsumed bytes (e.g. a partial frame) for the next read.
					this.readerBuffer.compact();
				} else {
					this.readerBuffer.release();
					this.readerBuffer = null;
				}
			}
		} else {
			// Nothing was read (would-block or immediate end of stream): the scratch buffer is unused.
			buffer.release();
		}
		if (count < 0) {
			this.deRegister();
		}
	}

	/**
	 * Drains the write queue, encoding and writing each message and completing its future, then
	 * either shuts the channel down (if {@link #close()} was requested) or restores OP_READ interest.
	 */
	@Override
	public void doWritable() throws IOException {
		// poll() until empty instead of size()/recursion: size() is O(n) on ConcurrentLinkedQueue and
		// recursing while producers keep enqueuing could overflow the stack.
		MessageFuture future;
		while ((future = this.writerMessage.poll()) != null) {
			this.writeMessage(future);
		}
		if (this.isClosing()) {
			this.shutdown();
		} else {
			this.doListener(EVENT_READABLE);
		}
	}

	/**
	 * Encodes one queued message, writes it fully to the socket and completes its future.
	 * Failures are reported via {@link #doException(Throwable)} and the future.
	 */
	private void writeMessage(MessageFuture future) {
		Buffer buffer = null;
		try {
			buffer = this.messageProtocol.encode(future.getMessage());
			if (null == buffer) {
				// Complete the future so the caller is not left waiting for a write that will never happen.
				future.setFailure(new IllegalStateException("message protocol encoded message to null"));
				return;
			}
			buffer.flip();
			// NOTE(review): in non-blocking mode a write may return 0 while the socket send buffer is full,
			// so this loop can busy-spin; keeping the partial buffer for the next OP_WRITE would avoid that.
			while (buffer.hasRemaining()) {
				this.updateOnWriter(buffer.write(this.socketChannel));
			}
			future.setSuccess();
		} catch (Exception e) {
			this.doException(e);
			future.setFailure(e);
		} finally {
			if (null != buffer) {
				buffer.release();
			}
		}
	}

	/**
	 * Closes the channel. Repeated calls return the same future.
	 *
	 * @return a future completed once the channel has been shut down
	 */
	@Override
	public synchronized ChannelContextFuture close() {
		// synchronized (like shutdown) so concurrent callers cannot create two close futures.
		if (this.isClosing()) {
			return this.closeFuture;
		}
		this.closeFuture = new ChannelContextFuture(this);
		this.shutdown();
		return this.closeFuture;
	}

	/** @return true once {@link #close()} has been called */
	protected boolean isClosing() {
		return null != this.closeFuture;
	}

	/**
	 * Cancels the selection key, closes the socket and fires the unregister event if the channel is
	 * still open; fails any messages still waiting to be written; and completes the close future if
	 * a close was requested.
	 */
	protected synchronized void shutdown() {
		if (this.isOpen()) {
			ChannelUtils.cancel(this.selectionKey);
			ChannelUtils.close(this.socketChannel);
			this.getChannelContextHandlerChain().fireUnRegister();
		}
		this.failPendingWrites();
		if (this.isClosing()) {
			this.closeFuture.setSuccess();
		}
	}

	/** Fails every queued message; once the socket is closed they can never be written. */
	private void failPendingWrites() {
		MessageFuture future;
		while ((future = this.writerMessage.poll()) != null) {
			future.setFailure(new IOException("channel is closed"));
		}
	}

	/**
	 * Captures the socket's local and remote addresses, mapping them to server/client according to
	 * which side of the connection this context is on.
	 */
	protected void updateAddress() {
		NetAddress localAddress = NetAddress.create((InetSocketAddress) this.socketChannel.socket().getLocalSocketAddress());
		NetAddress remoteAddress = NetAddress.create((InetSocketAddress) this.socketChannel.socket().getRemoteSocketAddress());
		if (this.isServerSide()) {
			this.serverAddress = localAddress;
			this.clientAddress = remoteAddress;
		} else {
			this.clientAddress = localAddress;
			this.serverAddress = remoteAddress;
		}
	}

}